-
Notifications
You must be signed in to change notification settings - Fork 15.1k
Enable pass instrumentation to signal failures. #163126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Jacques Pienaar (jpienaar) ChangesEnables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere. It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes. Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 . Full diff: https://github.com/llvm/llvm-project/pull/163126.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
index 16893c6db87b1..f0b0979a81ee3 100644
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -17,6 +17,7 @@
#include <optional>
namespace mlir {
+class PassInstrumentation;
namespace detail {
class OpToOpPassAdaptor;
struct OpPassManagerImpl;
@@ -334,6 +335,9 @@ class Pass {
/// Allow access to 'passOptions'.
friend class PassInfo;
+
+ /// Allow access to 'signalPassFailure'.
+ friend class PassInstrumentation;
};
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h
index 917bac4b22288..25a8e77be75ee 100644
--- a/mlir/include/mlir/Pass/PassInstrumentation.h
+++ b/mlir/include/mlir/Pass/PassInstrumentation.h
@@ -80,6 +80,8 @@ class PassInstrumentation {
/// name of the analysis that was computed, its TypeID, as well as the
/// current operation being analyzed.
virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {}
+
+ static void signalPassFailure(Pass *pass);
};
/// This class holds a collection of PassInstrumentation objects, and invokes
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 521c7c6be17b6..17ac475b42f4b 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -599,17 +599,20 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
if (pi)
pi->runBeforePass(pass, op);
- bool passFailed = false;
- op->getContext()->executeAction<PassExecutionAction>(
- [&]() {
- // Invoke the virtual runOnOperation method.
- if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
- adaptor->runOnOperation(verifyPasses);
- else
- pass->runOnOperation();
- passFailed = pass->passState->irAndPassFailed.getInt();
- },
- {op}, *pass);
+ bool passFailed = pass->passState->irAndPassFailed.getInt();
+ if (!passFailed) {
+ op->getContext()->executeAction<PassExecutionAction>(
+ [&]() {
+ // Invoke the virtual runOnOperation method.
+ if (auto *adaptor = dyn_cast<OpToOpPassAdaptor>(pass))
+ adaptor->runOnOperation(verifyPasses);
+ else
+ pass->runOnOperation();
+ passFailed = pass->passState->irAndPassFailed.getInt();
+ },
+ {op}, *pass);
+ }
+
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
@@ -640,10 +643,12 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
// Instrument after the pass has run.
if (pi) {
- if (passFailed)
+ if (passFailed) {
pi->runAfterPassFailed(pass, op);
- else
+ } else {
pi->runAfterPass(pass, op);
+ passFailed = passFailed || pass->passState->irAndPassFailed.getInt();
+ }
}
// Return if the pass signaled a failure.
@@ -1198,6 +1203,8 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
+void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }
+
//===----------------------------------------------------------------------===//
// PassInstrumentor
//===----------------------------------------------------------------------===//
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 7e618811eabf4..86c793384db11 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassInstrumentation.h"
#include "gtest/gtest.h"
#include <memory>
@@ -117,6 +118,103 @@ struct AddSecondAttrFunctionPass
}
};
+/// PassInstrumentation to count pass callbacks and signal pass failures.
+struct TestPassInstrumentation : public PassInstrumentation {
+ int beforePassCallbackCount = 0;
+ int afterPassCallbackCount = 0;
+ int afterPassFailedCallbackCount = 0;
+
+ bool failBeforePass = false;
+ bool failAfterPass = false;
+
+ void runBeforePass(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++beforePassCallbackCount;
+ if (failBeforePass)
+ signalPassFailure(pass);
+ }
+ void runAfterPass(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++afterPassCallbackCount;
+ if (failAfterPass)
+ signalPassFailure(pass);
+ }
+ void runAfterPassFailed(Pass *pass, Operation *op) override {
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+
+ ++afterPassFailedCallbackCount;
+ }
+};
+
+TEST(PassManagerTest, PassInstrumentation) {
+ MLIRContext context;
+ context.loadDialect<func::FuncDialect>();
+ Builder b(&context);
+
+ // Create a module with 1 function.
+ OwningOpRef<ModuleOp> module(ModuleOp::create(UnknownLoc::get(&context)));
+ auto func = func::FuncOp::create(b.getUnknownLoc(), "test_func",
+ b.getFunctionType({}, {}));
+ func.setPrivate();
+ module->push_back(func);
+
+ struct InstrumentationCounts {
+ int beforePass;
+ int afterPass;
+ int afterPassFailed;
+ };
+
+ auto runInstrumentation =
+ [&](bool failBefore,
+ bool failAfter) -> std::pair<LogicalResult, InstrumentationCounts> {
+ // Instantiate and run our pass.
+ auto pm = PassManager::on<ModuleOp>(&context);
+ auto instrumentation = std::make_unique<TestPassInstrumentation>();
+ auto *instrumentationPtr = instrumentation.get();
+ instrumentation->failBeforePass = failBefore;
+ instrumentation->failAfterPass = failAfter;
+ pm.addInstrumentation(std::move(instrumentation));
+ pm.addNestedPass<func::FuncOp>(std::make_unique<AddAttrFunctionPass>());
+ LogicalResult result = pm.run(module.get());
+
+ InstrumentationCounts counts = {
+ .beforePass = instrumentationPtr->beforePassCallbackCount,
+ .afterPass = instrumentationPtr->afterPassCallbackCount,
+ .afterPassFailed = instrumentationPtr->afterPassFailedCallbackCount};
+ return {result, counts};
+ };
+
+ for (bool failBefore : {false, true}) {
+ for (bool failAfter : {false, true}) {
+ auto [result, counts] = runInstrumentation(failBefore, failAfter);
+
+ InstrumentationCounts expected;
+ if (failBefore) {
+ EXPECT_TRUE(failed(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 0, .afterPassFailed = 1};
+ } else if (failAfter) {
+ EXPECT_TRUE(failed(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+ } else {
+ EXPECT_TRUE(succeeded(result))
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ expected = {.beforePass = 1, .afterPass = 1, .afterPassFailed = 0};
+ }
+
+ EXPECT_EQ(counts.beforePass, expected.beforePass)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ EXPECT_EQ(counts.afterPass, expected.afterPass)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ EXPECT_EQ(counts.afterPassFailed, expected.afterPassFailed)
+ << "failBefore=" << failBefore << ", failAfter=" << failAfter;
+ }
+ }
+}
+
TEST(PassManagerTest, ExecutionAction) {
MLIRContext context;
context.loadDialect<func::FuncDialect>();
|
You can test this locally with the following command:git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Pass/Pass.h mlir/include/mlir/Pass/PassInstrumentation.h mlir/lib/Pass/Pass.cpp mlir/unittests/Pass/PassManagerTest.cpp --diff_from_common_commit
View the diff from clang-format here.diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 17ac475b4..9dc947a78 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -613,7 +613,6 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op,
{op}, *pass);
}
-
// Invalidate any non preserved analyses.
am.invalidate(pass->passState->preservedAnalyses);
@@ -1203,7 +1202,9 @@ void PassInstrumentation::runBeforePipeline(
void PassInstrumentation::runAfterPipeline(
std::optional<OperationName> name, const PipelineParentInfo &parentInfo) {}
-void PassInstrumentation::signalPassFailure(Pass *pass) { pass->signalPassFailure(); }
+void PassInstrumentation::signalPassFailure(Pass *pass) {
+ pass->signalPassFailure();
+}
//===----------------------------------------------------------------------===//
// PassInstrumentor
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
index 86c793384..50cd8ee1c 100644
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -128,21 +128,24 @@ struct TestPassInstrumentation : public PassInstrumentation {
bool failAfterPass = false;
void runBeforePass(Pass *pass, Operation *op) override {
- if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+ return;
++beforePassCallbackCount;
if (failBeforePass)
signalPassFailure(pass);
}
void runAfterPass(Pass *pass, Operation *op) override {
- if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+ return;
++afterPassCallbackCount;
if (failAfterPass)
signalPassFailure(pass);
}
void runAfterPassFailed(Pass *pass, Operation *op) override {
- if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>()) return;
+ if (pass->getTypeID() != TypeID::get<AddAttrFunctionPass>())
+ return;
++afterPassFailedCallbackCount;
}
|
| /// current operation being analyzed. | ||
| virtual void runAfterAnalysis(StringRef name, TypeID id, Operation *op) {} | ||
|
|
||
| static void signalPassFailure(Pass *pass); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add documentation for this method, also should this be a static method or an instance method? I know it doesn't have to be an instance method but that would help keep the scope of API exposure slimmer (otherwise, should we just make signalPassFailure public?)
| passFailed = pass->passState->irAndPassFailed.getInt(); | ||
| }, | ||
| {op}, *pass); | ||
| bool passFailed = pass->passState->irAndPassFailed.getInt(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is non-intuitive to me, should be documented.
Enables adding instrumentation to pass manager that can track/flag invariants. This would be useful for cases where one some tighter requirements than the general dialects or for a phase of conversion that elsewhere.
It would enable making verify also just a regular instrumentation I believe, but also a non-goal as that is a first class concept and baseline for the ops and passes.
Would have enabled some of the requirements of https://discourse.llvm.org/t/pre-verification-logic-before-running-conversion-pass-in-mlir/88318/10 .